{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 6-3. Training Models on a GPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep learning training is often extremely time-consuming. Training a model for a few hours is routine, training for several days is common, and sometimes training can take dozens of days.\n",
    "\n",
    "The time cost of training comes mainly from two parts: data preparation and parameter iteration.\n",
    "\n",
    "When data preparation is the main bottleneck, we can use more worker processes to prepare the data.\n",
    "\n",
    "When parameter iteration becomes the main bottleneck, the usual remedy is to use a GPU for acceleration."
   ]
  },
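  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the first remedy (illustrative code, not part of this section's later examples): `torch.utils.data.DataLoader` accepts a `num_workers` argument that prepares batches in background worker processes.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "ds = TensorDataset(torch.rand(1000, 2), torch.rand(1000, 1))\n",
    "# num_workers > 0 loads batches in background processes,\n",
    "# which helps when data preparation is the bottleneck\n",
    "dl = DataLoader(ds, batch_size=128, shuffle=True, num_workers=2)\n",
    "features, labels = next(iter(dl))\n",
    "print(features.shape)  # torch.Size([128, 2])\n",
    "```"
   ]
  },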
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install -q torchkeras\n",
    "!pip install -q -U torchmetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.__version__ =  2.4.0\n",
      "torchkeras.__version__ =  4.0.0\n",
      "torchmetrics.__version__ =  1.4.1\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "import torchkeras \n",
    "import torchmetrics\n",
    "\n",
    "print(\"torch.__version__ = \",torch.__version__)\n",
    "print(\"torchkeras.__version__ = \",torchkeras.__version__)\n",
    "print(\"torchmetrics.__version__ = \",torchmetrics.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: the code in this section runs correctly only on a machine with a GPU.\n",
    "\n",
    "If you do not have a GPU, you can use the free GPUs on Kaggle.\n",
    "\n",
    "If you have never used Kaggle, see my Bilibili video tutorial:\n",
    "\n",
    "《Kaggle免费GPU使用攻略》 (A Guide to Kaggle's Free GPUs)\n",
    "\n",
    "https://www.bilibili.com/video/BV1oa411u7uR\n",
    "\n",
    "The source code of this article is available on Kaggle and can be forked and run here:\n",
    "\n",
    "https://www.kaggle.com/lyhue1991/pytorch-gpu-examples\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using a GPU to accelerate a model in PyTorch is very simple: just move the model and the data to the GPU. The core code is only a few lines.\n",
    "\n",
    "```python\n",
    "# define the model\n",
    "... \n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device) # move the model to the GPU\n",
    "\n",
    "# train the model\n",
    "...\n",
    "\n",
    "features = features.to(device) # move the data to the GPU\n",
    "labels = labels.to(device) # or: labels = labels.cuda() if torch.cuda.is_available() else labels\n",
    "...\n",
    "```\n",
    "\n",
    "Training with multiple GPUs is just as simple: wrap the model as a data-parallel model. After the model is moved to the GPU, a replica is kept on each GPU and each batch of data is split evenly across the GPUs. The core code is as follows.\n",
    "\n",
    "```python\n",
    "# define the model\n",
    "... \n",
    "\n",
    "if torch.cuda.device_count() > 1:\n",
    "    model = nn.DataParallel(model) # wrap as a data-parallel model\n",
    "\n",
    "# train the model\n",
    "...\n",
    "features = features.to(device) # move the data to the GPU\n",
    "labels = labels.to(device) # or: labels = labels.cuda() if torch.cuda.is_available() else labels\n",
    "...\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Summary of GPU Operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:21:49.406397Z",
     "iopub.status.busy": "2023-08-02T09:21:49.405676Z",
     "iopub.status.idle": "2023-08-02T09:21:49.469074Z",
     "shell.execute_reply": "2023-08-02T09:21:49.467906Z",
     "shell.execute_reply.started": "2023-08-02T09:21:49.406358Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "if_cuda= True\n",
      "gpu_count= 1\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "# 1. Check GPU availability\n",
    "if_cuda = torch.cuda.is_available()\n",
    "print(\"if_cuda=\",if_cuda)\n",
    "\n",
    "gpu_count = torch.cuda.device_count()\n",
    "print(\"gpu_count=\",gpu_count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:21:50.949675Z",
     "iopub.status.busy": "2023-08-02T09:21:50.949297Z",
     "iopub.status.idle": "2023-08-02T09:21:55.584912Z",
     "shell.execute_reply": "2023-08-02T09:21:55.583660Z",
     "shell.execute_reply.started": "2023-08-02T09:21:50.949642Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n",
      "True\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "# 2. Move tensors between the GPU and the CPU\n",
    "tensor = torch.rand((100,100))\n",
    "tensor_gpu = tensor.to(\"cuda:0\") # or: tensor_gpu = tensor.cuda()\n",
    "print(tensor_gpu.device)\n",
    "print(tensor_gpu.is_cuda)\n",
    "\n",
    "tensor_cpu = tensor_gpu.to(\"cpu\") # or: tensor_cpu = tensor_gpu.cpu()\n",
    "print(tensor_cpu.device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:21:55.587734Z",
     "iopub.status.busy": "2023-08-02T09:21:55.587050Z",
     "iopub.status.idle": "2023-08-02T09:21:55.597566Z",
     "shell.execute_reply": "2023-08-02T09:21:55.596260Z",
     "shell.execute_reply.started": "2023-08-02T09:21:55.587689Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "True\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# 3. Move all tensors in a model to the GPU\n",
    "net = nn.Linear(2,1)\n",
    "print(next(net.parameters()).is_cuda)\n",
    "net.to(\"cuda:0\") # moves every parameter tensor of the model to the GPU in place; note that reassigning net = net.to(\"cuda:0\") is not needed\n",
    "print(next(net.parameters()).is_cuda)\n",
    "print(next(net.parameters()).device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:21:55.599678Z",
     "iopub.status.busy": "2023-08-02T09:21:55.599312Z",
     "iopub.status.idle": "2023-08-02T09:21:55.621575Z",
     "shell.execute_reply": "2023-08-02T09:21:55.620550Z",
     "shell.execute_reply.started": "2023-08-02T09:21:55.599640Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cpu\n",
      "[0]\n",
      "cuda:0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 4. Create a data-parallel model that supports multiple GPUs\n",
    "linear = nn.Linear(2,1)\n",
    "print(next(linear.parameters()).device)\n",
    "\n",
    "model = nn.DataParallel(linear)\n",
    "print(model.device_ids)\n",
    "print(next(model.module.parameters()).device) \n",
    "\n",
    "# Note: when saving parameters, save the state_dict of model.module\n",
    "torch.save(model.module.state_dict(), \"model_parameter.pt\") \n",
    "\n",
    "linear = nn.Linear(2,1)\n",
    "linear.load_state_dict(torch.load(\"model_parameter.pt\")) \n"
   ]
  },
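  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A related gotcha, sketched here as a hedged aside: if a state_dict was saved on a GPU machine, loading it on a CPU-only machine requires the `map_location` argument of `torch.load`, which remaps the saved storages to another device.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "linear = nn.Linear(2, 1)\n",
    "torch.save(linear.state_dict(), \"model_parameter.pt\")\n",
    "\n",
    "# map_location remaps storages to the given device at load time,\n",
    "# so a checkpoint written on a GPU machine can be read on CPU\n",
    "state = torch.load(\"model_parameter.pt\", map_location=\"cpu\")\n",
    "linear2 = nn.Linear(2, 1)\n",
    "linear2.load_state_dict(state)\n",
    "print(next(linear2.parameters()).device)  # cpu\n",
    "```"
   ]
  },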
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Matrix Multiplication Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we run the same matrix multiplication on the CPU and on the GPU and compare their speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:21:58.132801Z",
     "iopub.status.busy": "2023-08-02T09:21:58.131981Z",
     "iopub.status.idle": "2023-08-02T09:21:58.138927Z",
     "shell.execute_reply": "2023-08-02T09:21:58.137799Z",
     "shell.execute_reply.started": "2023-08-02T09:21:58.132746Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import torch \n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:21:59.222992Z",
     "iopub.status.busy": "2023-08-02T09:21:59.222619Z",
     "iopub.status.idle": "2023-08-02T09:21:59.871529Z",
     "shell.execute_reply": "2023-08-02T09:21:59.870275Z",
     "shell.execute_reply.started": "2023-08-02T09:21:59.222960Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6000046730041504\n",
      "cpu\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "# on the CPU\n",
    "a = torch.rand((10000,200))\n",
    "b = torch.rand((200,10000))\n",
    "tic = time.time()\n",
    "c = torch.matmul(a,b)\n",
    "toc = time.time()\n",
    "\n",
    "print(toc-tic)\n",
    "print(a.device)\n",
    "print(b.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:01.349894Z",
     "iopub.status.busy": "2023-08-02T09:22:01.349166Z",
     "iopub.status.idle": "2023-08-02T09:22:02.226856Z",
     "shell.execute_reply": "2023-08-02T09:22:02.224728Z",
     "shell.execute_reply.started": "2023-08-02T09:22:01.349856Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8443384170532227\n",
      "cuda:0\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# on the GPU\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "a = torch.rand((10000,200),device = device) # a tensor can be created directly on the GPU\n",
    "b = torch.rand((200,10000)) # or created on the CPU first and then moved to the GPU\n",
    "b = b.to(device) # or: b = b.cuda() if torch.cuda.is_available() else b\n",
    "tic = time.time()\n",
    "c = torch.matmul(a,b)\n",
    "toc = time.time()\n",
    "print(toc-tic)\n",
    "print(a.device)\n",
    "print(b.device)\n"
   ]
  },
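  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One caveat about the timing above: CUDA kernels launch asynchronously, so `torch.matmul` may return before the GPU has finished, and the first GPU call also pays one-time initialization cost. A more careful timing sketch (assumptions: a warm-up run plus `torch.cuda.synchronize()` around the timed region) looks like this:\n",
    "\n",
    "```python\n",
    "import time\n",
    "import torch\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "a = torch.rand((10000, 200), device=device)\n",
    "b = torch.rand((200, 10000), device=device)\n",
    "\n",
    "c = torch.matmul(a, b)  # warm-up: triggers CUDA context and kernel setup\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.synchronize()  # drain pending kernels before starting the clock\n",
    "tic = time.time()\n",
    "c = torch.matmul(a, b)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.synchronize()  # wait for the matmul to actually finish\n",
    "toc = time.time()\n",
    "print(toc - tic)\n",
    "```"
   ]
  },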
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Linear Regression Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-17T14:58:49.525724Z",
     "iopub.status.busy": "2022-07-17T14:58:49.525304Z",
     "iopub.status.idle": "2022-07-17T14:58:49.546334Z",
     "shell.execute_reply": "2022-07-17T14:58:49.544588Z",
     "shell.execute_reply.started": "2022-07-17T14:58:49.525694Z"
    }
   },
   "source": [
    "Below we compare the efficiency of training a linear regression model on the CPU versus on the GPU."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Using the CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:05.233030Z",
     "iopub.status.busy": "2023-08-02T09:22:05.232639Z",
     "iopub.status.idle": "2023-08-02T09:22:05.297141Z",
     "shell.execute_reply": "2023-08-02T09:22:05.296102Z",
     "shell.execute_reply.started": "2023-08-02T09:22:05.232997Z"
    }
   },
   "outputs": [],
   "source": [
    "# prepare data\n",
    "n = 1000000 # number of samples\n",
    "\n",
    "X = 10*torch.rand([n,2])-5.0  # torch.rand draws from a uniform distribution\n",
    "w0 = torch.tensor([[2.0,-3.0]])\n",
    "b0 = torch.tensor([[10.0]])\n",
    "Y = X@w0.t() + b0 + torch.normal( 0.0,2.0,size = [n,1])  # @ is matrix multiplication; add Gaussian noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:06.677267Z",
     "iopub.status.busy": "2023-08-02T09:22:06.676242Z",
     "iopub.status.idle": "2023-08-02T09:22:06.685850Z",
     "shell.execute_reply": "2023-08-02T09:22:06.684746Z",
     "shell.execute_reply.started": "2023-08-02T09:22:06.677187Z"
    }
   },
   "outputs": [],
   "source": [
    "# define the model\n",
    "class LinearRegression(nn.Module): \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = nn.Parameter(torch.randn_like(w0))\n",
    "        self.b = nn.Parameter(torch.zeros_like(b0))\n",
    "    # forward pass\n",
    "    def forward(self,x): \n",
    "        return x@self.w.t() + self.b\n",
    "        \n",
    "linear = LinearRegression() \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:08.718803Z",
     "iopub.status.busy": "2023-08-02T09:22:08.718085Z",
     "iopub.status.idle": "2023-08-02T09:22:13.949452Z",
     "shell.execute_reply": "2023-08-02T09:22:13.948304Z",
     "shell.execute_reply.started": "2023-08-02T09:22:08.718765Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'epoch': 0, 'loss': 258.9547119140625}\n",
      "{'epoch': 50, 'loss': 33.212669372558594}\n",
      "{'epoch': 100, 'loss': 9.038525581359863}\n",
      "{'epoch': 150, 'loss': 4.485360145568848}\n",
      "{'epoch': 200, 'loss': 4.017963409423828}\n",
      "{'epoch': 250, 'loss': 3.994182825088501}\n",
      "{'epoch': 300, 'loss': 3.993659734725952}\n",
      "{'epoch': 350, 'loss': 3.9936563968658447}\n",
      "{'epoch': 400, 'loss': 3.9936563968658447}\n",
      "{'epoch': 450, 'loss': 3.9936563968658447}\n",
      "time used: 5.222184896469116\n"
     ]
    }
   ],
   "source": [
    "# train the model\n",
    "optimizer = torch.optim.Adam(linear.parameters(),lr = 0.1)\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "def train(epochs):\n",
    "    tic = time.time()\n",
    "    for epoch in range(epochs):\n",
    "        optimizer.zero_grad()\n",
    "        Y_pred = linear(X) \n",
    "        loss = loss_fn(Y_pred,Y)\n",
    "        loss.backward() \n",
    "        optimizer.step()\n",
    "        if epoch%50==0:\n",
    "            print({\"epoch\":epoch,\"loss\":loss.item()})\n",
    "    toc = time.time()\n",
    "    print(\"time used:\",toc-tic)\n",
    "\n",
    "train(500)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Using the GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:13.952419Z",
     "iopub.status.busy": "2023-08-02T09:22:13.951355Z",
     "iopub.status.idle": "2023-08-02T09:22:13.998524Z",
     "shell.execute_reply": "2023-08-02T09:22:13.997457Z",
     "shell.execute_reply.started": "2023-08-02T09:22:13.952376Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.cuda.is_available() =  True\n",
      "X.device: cuda:0\n",
      "Y.device: cuda:0\n"
     ]
    }
   ],
   "source": [
    "# prepare data\n",
    "n = 1000000 # number of samples\n",
    "\n",
    "X = 10*torch.rand([n,2])-5.0  # torch.rand draws from a uniform distribution\n",
    "w0 = torch.tensor([[2.0,-3.0]])\n",
    "b0 = torch.tensor([[10.0]])\n",
    "Y = X@w0.t() + b0 + torch.normal( 0.0,2.0,size = [n,1])  # @ is matrix multiplication; add Gaussian noise\n",
    "\n",
    "# move the data to the GPU\n",
    "print(\"torch.cuda.is_available() = \",torch.cuda.is_available())\n",
    "X = X.cuda()\n",
    "Y = Y.cuda()\n",
    "print(\"X.device:\",X.device)\n",
    "print(\"Y.device:\",Y.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:14.856923Z",
     "iopub.status.busy": "2023-08-02T09:22:14.856519Z",
     "iopub.status.idle": "2023-08-02T09:22:14.867761Z",
     "shell.execute_reply": "2023-08-02T09:22:14.866595Z",
     "shell.execute_reply.started": "2023-08-02T09:22:14.856887Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "if on cuda: True\n"
     ]
    }
   ],
   "source": [
    "# define the model\n",
    "class LinearRegression(nn.Module): \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = nn.Parameter(torch.randn_like(w0))\n",
    "        self.b = nn.Parameter(torch.zeros_like(b0))\n",
    "    # forward pass\n",
    "    def forward(self,x): \n",
    "        return x@self.w.t() + self.b\n",
    "        \n",
    "linear = LinearRegression() \n",
    "\n",
    "# move the model to the GPU\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "linear.to(device)\n",
    "\n",
    "# check whether the model has been moved to the GPU\n",
    "print(\"if on cuda:\",next(linear.parameters()).is_cuda)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:21.232004Z",
     "iopub.status.busy": "2023-08-02T09:22:21.231614Z",
     "iopub.status.idle": "2023-08-02T09:22:21.785143Z",
     "shell.execute_reply": "2023-08-02T09:22:21.783907Z",
     "shell.execute_reply.started": "2023-08-02T09:22:21.231970Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'epoch': 0, 'loss': 153.66574096679688}\n",
      "{'epoch': 50, 'loss': 32.86173629760742}\n",
      "{'epoch': 100, 'loss': 9.03520679473877}\n",
      "{'epoch': 150, 'loss': 4.485783576965332}\n",
      "{'epoch': 200, 'loss': 4.018568515777588}\n",
      "{'epoch': 250, 'loss': 3.994813919067383}\n",
      "{'epoch': 300, 'loss': 3.9942924976348877}\n",
      "{'epoch': 350, 'loss': 3.994288921356201}\n",
      "{'epoch': 400, 'loss': 3.9942891597747803}\n",
      "{'epoch': 450, 'loss': 3.9942891597747803}\n",
      "time used: 0.5444216728210449\n"
     ]
    }
   ],
   "source": [
    "# train the model\n",
    "optimizer = torch.optim.Adam(linear.parameters(),lr = 0.1)\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "def train(epochs):\n",
    "    tic = time.time()\n",
    "    for epoch in range(epochs):\n",
    "        optimizer.zero_grad()\n",
    "        Y_pred = linear(X) \n",
    "        loss = loss_fn(Y_pred,Y)\n",
    "        loss.backward() \n",
    "        optimizer.step()\n",
    "        if epoch%50==0:\n",
    "            print({\"epoch\":epoch,\"loss\":loss.item()})\n",
    "    toc = time.time()\n",
    "    print(\"time used:\",toc-tic)\n",
    "    \n",
    "train(500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Image Classification Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:24.911548Z",
     "iopub.status.busy": "2023-08-02T09:22:24.911130Z",
     "iopub.status.idle": "2023-08-02T09:22:24.917106Z",
     "shell.execute_reply": "2023-08-02T09:22:24.915927Z",
     "shell.execute_reply.started": "2023-08-02T09:22:24.911513Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "import torchvision \n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:26.656218Z",
     "iopub.status.busy": "2023-08-02T09:22:26.655272Z",
     "iopub.status.idle": "2023-08-02T09:22:27.893139Z",
     "shell.execute_reply": "2023-08-02T09:22:27.892010Z",
     "shell.execute_reply.started": "2023-08-02T09:22:26.656155Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8f4936d7cceb439494d46f71f02f7518",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/9912422 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "19d96f782ad148a59f1e1f4f98195d09",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/28881 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1428a4921c604cb783c8d17e444e0d55",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1648877 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7aa4643aca6a4f86be6ffdc0f4bb9bea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4542 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw\n",
      "\n",
      "60000\n",
      "10000\n"
     ]
    }
   ],
   "source": [
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "ds_train = torchvision.datasets.MNIST(root=\"mnist/\",train=True,download=True,transform=transform)\n",
    "ds_val = torchvision.datasets.MNIST(root=\"mnist/\",train=False,download=True,transform=transform)\n",
    "\n",
    "dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=2)\n",
    "dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=2)\n",
    "\n",
    "print(len(ds_train))\n",
    "print(len(ds_val))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:30.365184Z",
     "iopub.status.busy": "2023-08-02T09:22:30.364138Z",
     "iopub.status.idle": "2023-08-02T09:22:30.380062Z",
     "shell.execute_reply": "2023-08-02T09:22:30.378908Z",
     "shell.execute_reply.started": "2023-08-02T09:22:30.365148Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (dropout): Dropout2d(p=0.1, inplace=False)\n",
      "  (adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))\n",
      "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
      "  (linear1): Linear(in_features=64, out_features=32, bias=True)\n",
      "  (relu): ReLU()\n",
      "  (linear2): Linear(in_features=32, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "def create_net():\n",
    "    net = nn.Sequential()\n",
    "    net.add_module(\"conv1\",nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3))\n",
    "    net.add_module(\"pool1\",nn.MaxPool2d(kernel_size = 2,stride = 2))\n",
    "    net.add_module(\"conv2\",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))\n",
    "    net.add_module(\"pool2\",nn.MaxPool2d(kernel_size = 2,stride = 2))\n",
    "    net.add_module(\"dropout\",nn.Dropout2d(p = 0.1))\n",
    "    net.add_module(\"adaptive_pool\",nn.AdaptiveMaxPool2d((1,1)))\n",
    "    net.add_module(\"flatten\",nn.Flatten())\n",
    "    net.add_module(\"linear1\",nn.Linear(64,32))\n",
    "    net.add_module(\"relu\",nn.ReLU())\n",
    "    net.add_module(\"linear2\",nn.Linear(32,10))\n",
    "    return net\n",
    "\n",
    "net = create_net()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Training on the CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:22:53.737134Z",
     "iopub.status.busy": "2023-08-02T09:22:53.736757Z",
     "iopub.status.idle": "2023-08-02T09:24:50.589851Z",
     "shell.execute_reply": "2023-08-02T09:24:50.588548Z",
     "shell.execute_reply.started": "2023-08-02T09:22:53.737099Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "================================================================================2023-08-02 09:22:53\n",
      "Epoch 1 / 3\n",
      "\n",
      "100%|██████████| 469/469 [00:35<00:00, 13.11it/s, train_acc=0.9, train_loss=0.312]   \n",
      "100%|██████████| 79/79 [00:03<00:00, 22.00it/s, val_acc=0.965, val_loss=0.111] \n",
      "<<<<<< reach best val_acc : 0.9646000266075134 >>>>>>\n",
      "\n",
      "================================================================================2023-08-02 09:23:33\n",
      "Epoch 2 / 3\n",
      "\n",
      "100%|██████████| 469/469 [00:35<00:00, 13.29it/s, train_acc=0.966, train_loss=0.109] \n",
      "100%|██████████| 79/79 [00:03<00:00, 22.50it/s, val_acc=0.975, val_loss=0.0814]\n",
      "<<<<<< reach best val_acc : 0.9749000072479248 >>>>>>\n",
      "\n",
      "================================================================================2023-08-02 09:24:12\n",
      "Epoch 3 / 3\n",
      "\n",
      "100%|██████████| 469/469 [00:34<00:00, 13.50it/s, train_acc=0.971, train_loss=0.095] \n",
      "100%|██████████| 79/79 [00:03<00:00, 23.03it/s, val_acc=0.964, val_loss=0.12]   \n",
      "<<<<<< val_acc without improvement in 1 epoch, early stopping >>>>>>\n"
     ]
    }
   ],
   "source": [
    "import os,sys,time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import datetime \n",
    "from tqdm import tqdm \n",
    "\n",
    "import torch\n",
    "from torch import nn \n",
    "from copy import deepcopy\n",
    "from torchmetrics import Accuracy\n",
    "# Note: for multiclass tasks use the metrics in torchmetrics; for binary classification use the metrics in torchkeras.metrics\n",
    "\n",
    "def printlog(info):\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "    print(str(info)+\"\\n\")\n",
    "    \n",
    "\n",
    "net = create_net() \n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)   \n",
    "metrics_dict = {\"acc\":Accuracy(task='multiclass',num_classes=10)}\n",
    "\n",
    "epochs = 3 \n",
    "ckpt_path='checkpoint.pt'\n",
    "\n",
    "# early-stopping settings\n",
    "monitor=\"val_acc\"\n",
    "patience=1\n",
    "mode=\"max\"\n",
    "\n",
    "history = {}\n",
    "\n",
    "for epoch in range(1, epochs+1):\n",
    "    printlog(\"Epoch {0} / {1}\".format(epoch, epochs))\n",
    "\n",
    "    # 1, train -------------------------------------------------  \n",
    "    net.train()\n",
    "    \n",
    "    total_loss,step = 0,0\n",
    "    \n",
    "    loop = tqdm(enumerate(dl_train), total =len(dl_train),file=sys.stdout)\n",
    "    train_metrics_dict = deepcopy(metrics_dict) \n",
    "    \n",
    "    for i, batch in loop: \n",
    "        \n",
    "        features,labels = batch\n",
    "        #forward\n",
    "        preds = net(features)\n",
    "        loss = loss_fn(preds,labels)\n",
    "        \n",
    "        #backward\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "            \n",
    "        #metrics\n",
    "        step_metrics = {\"train_\"+name:metric_fn(preds, labels).item() \n",
    "                        for name,metric_fn in train_metrics_dict.items()}\n",
    "        \n",
    "        step_log = dict({\"train_loss\":loss.item()},**step_metrics)\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        \n",
    "        step+=1\n",
    "        if i!=len(dl_train)-1:\n",
    "            loop.set_postfix(**step_log)\n",
    "        else:\n",
    "            epoch_loss = total_loss/step\n",
    "            epoch_metrics = {\"train_\"+name:metric_fn.compute().item() \n",
    "                             for name,metric_fn in train_metrics_dict.items()}\n",
    "            epoch_log = dict({\"train_loss\":epoch_loss},**epoch_metrics)\n",
    "            loop.set_postfix(**epoch_log)\n",
    "\n",
    "            for name,metric_fn in train_metrics_dict.items():\n",
    "                metric_fn.reset()\n",
    "                \n",
    "    for name, metric in epoch_log.items():\n",
    "        history[name] = history.get(name, []) + [metric]\n",
    "        \n",
    "\n",
    "    # 2, validate -------------------------------------------------\n",
    "    net.eval()\n",
    "    \n",
    "    total_loss,step = 0,0\n",
    "    loop = tqdm(enumerate(dl_val), total =len(dl_val),file=sys.stdout)\n",
    "    \n",
    "    val_metrics_dict = deepcopy(metrics_dict) \n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for i, batch in loop: \n",
    "\n",
    "            features,labels = batch\n",
    "            \n",
    "            #forward\n",
    "            preds = net(features)\n",
    "            loss = loss_fn(preds,labels)\n",
    "\n",
    "            #metrics\n",
    "            step_metrics = {\"val_\"+name:metric_fn(preds, labels).item() \n",
    "                            for name,metric_fn in val_metrics_dict.items()}\n",
    "\n",
    "            step_log = dict({\"val_loss\":loss.item()},**step_metrics)\n",
    "\n",
    "            total_loss += loss.item()\n",
    "            step+=1\n",
    "            if i!=len(dl_val)-1:\n",
    "                loop.set_postfix(**step_log)\n",
    "            else:\n",
    "                epoch_loss = (total_loss/step)\n",
    "                epoch_metrics = {\"val_\"+name:metric_fn.compute().item() \n",
    "                                 for name,metric_fn in val_metrics_dict.items()}\n",
    "                epoch_log = dict({\"val_loss\":epoch_loss},**epoch_metrics)\n",
    "                loop.set_postfix(**epoch_log)\n",
    "\n",
    "                for name,metric_fn in val_metrics_dict.items():\n",
    "                    metric_fn.reset()\n",
    "                    \n",
    "    epoch_log[\"epoch\"] = epoch           \n",
    "    for name, metric in epoch_log.items():\n",
    "        history[name] = history.get(name, []) + [metric]\n",
    "\n",
    "    # 3, early-stopping -------------------------------------------------\n",
    "    arr_scores = history[monitor]\n",
    "    best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n",
    "    if best_score_idx==len(arr_scores)-1:\n",
    "        torch.save(net.state_dict(),ckpt_path)\n",
    "        print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n",
    "             arr_scores[best_score_idx]))\n",
    "    if len(arr_scores)-best_score_idx>patience:\n",
    "        print(\"<<<<<< {} without improvement in {} epoch, early stopping >>>>>>\".format(\n",
    "            monitor,patience))\n",
    "        break \n",
    "\n",
    "# restore the best checkpoint after training ends\n",
    "net.load_state_dict(torch.load(ckpt_path))\n",
    "    \n",
    "dfhistory = pd.DataFrame(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On CPU, each epoch takes about 40 seconds."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2, Training with GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:27:40.875092Z",
     "iopub.status.busy": "2023-08-02T09:27:40.874670Z",
     "iopub.status.idle": "2023-08-02T09:28:18.788065Z",
     "shell.execute_reply": "2023-08-02T09:28:18.786947Z",
     "shell.execute_reply.started": "2023-08-02T09:27:40.875053Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "================================================================================2023-08-02 09:27:40\n",
      "Epoch 1 / 5\n",
      "\n",
      "100%|██████████| 469/469 [00:14<00:00, 33.06it/s, train_acc=0.91, train_loss=0.283]  \n",
      "100%|██████████| 79/79 [00:01<00:00, 56.00it/s, val_acc=0.972, val_loss=0.0912]\n",
      "<<<<<< reach best val_acc : 0.972000002861023 >>>>>>\n",
      "\n",
      "================================================================================2023-08-02 09:27:56\n",
      "Epoch 2 / 5\n",
      "\n",
      "100%|██████████| 469/469 [00:09<00:00, 50.56it/s, train_acc=0.968, train_loss=0.105] \n",
      "100%|██████████| 79/79 [00:01<00:00, 40.13it/s, val_acc=0.98, val_loss=0.0672] \n",
      "<<<<<< reach best val_acc : 0.9800000190734863 >>>>>>\n",
      "\n",
      "================================================================================2023-08-02 09:28:08\n",
      "Epoch 3 / 5\n",
      "\n",
      "100%|██████████| 469/469 [00:09<00:00, 51.60it/s, train_acc=0.972, train_loss=0.0926]\n",
      "100%|██████████| 79/79 [00:01<00:00, 55.83it/s, val_acc=0.964, val_loss=0.121] \n",
      "<<<<<< val_acc without improvement in 1 epoch, early stopping >>>>>>\n"
     ]
    }
   ],
   "source": [
    "import os,sys,time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import datetime \n",
    "from tqdm import tqdm \n",
    "\n",
    "import torch\n",
    "from torch import nn \n",
    "from copy import deepcopy\n",
    "from torchmetrics import Accuracy\n",
    "# Note: use metrics from torchmetrics for multi-class tasks, and metrics from torchkeras.metrics for binary tasks\n",
    "\n",
    "def printlog(info):\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "    print(str(info)+\"\\n\")\n",
    "    \n",
    "net = create_net() \n",
    "\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)   \n",
    "metrics_dict = {\"acc\":Accuracy(task='multiclass',num_classes=10)}\n",
    "\n",
    "\n",
    "# ========================= move the model to the GPU ==============================\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net.to(device)\n",
    "loss_fn.to(device)\n",
    "for name,fn in metrics_dict.items():\n",
    "    fn.to(device)\n",
    "# ====================================================================\n",
    "\n",
    "\n",
    "epochs = 5 \n",
    "ckpt_path='checkpoint.pt'\n",
    "\n",
    "# early stopping settings\n",
    "monitor=\"val_acc\"\n",
    "patience=1\n",
    "mode=\"max\"\n",
    "\n",
    "history = {}\n",
    "\n",
    "for epoch in range(1, epochs+1):\n",
    "    printlog(\"Epoch {0} / {1}\".format(epoch, epochs))\n",
    "\n",
    "    # 1，train -------------------------------------------------  \n",
    "    net.train()\n",
    "    \n",
    "    total_loss,step = 0,0\n",
    "    \n",
    "    loop = tqdm(enumerate(dl_train), total =len(dl_train),file=sys.stdout)\n",
    "    train_metrics_dict = deepcopy(metrics_dict) \n",
    "    \n",
    "    for i, batch in loop: \n",
    "        \n",
    "        features,labels = batch\n",
    "        \n",
    "        # ========================= move the data to the GPU ==============================\n",
    "        features = features.to(device)\n",
    "        labels = labels.to(device)\n",
    "        # ====================================================================\n",
    "        \n",
    "        #forward\n",
    "        preds = net(features)\n",
    "        loss = loss_fn(preds,labels)\n",
    "        \n",
    "        #backward\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "            \n",
    "        #metrics\n",
    "        step_metrics = {\"train_\"+name:metric_fn(preds, labels).item() \n",
    "                        for name,metric_fn in train_metrics_dict.items()}\n",
    "        \n",
    "        step_log = dict({\"train_loss\":loss.item()},**step_metrics)\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        \n",
    "        step+=1\n",
    "        if i!=len(dl_train)-1:\n",
    "            loop.set_postfix(**step_log)\n",
    "        else:\n",
    "            epoch_loss = total_loss/step\n",
    "            epoch_metrics = {\"train_\"+name:metric_fn.compute().item() \n",
    "                             for name,metric_fn in train_metrics_dict.items()}\n",
    "            epoch_log = dict({\"train_loss\":epoch_loss},**epoch_metrics)\n",
    "            loop.set_postfix(**epoch_log)\n",
    "\n",
    "            for name,metric_fn in train_metrics_dict.items():\n",
    "                metric_fn.reset()\n",
    "                \n",
    "    for name, metric in epoch_log.items():\n",
    "        history[name] = history.get(name, []) + [metric]\n",
    "        \n",
    "\n",
    "    # 2，validate -------------------------------------------------\n",
    "    net.eval()\n",
    "    \n",
    "    total_loss,step = 0,0\n",
    "    loop = tqdm(enumerate(dl_val), total =len(dl_val),file=sys.stdout)\n",
    "    \n",
    "    val_metrics_dict = deepcopy(metrics_dict) \n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for i, batch in loop: \n",
    "\n",
    "            features,labels = batch\n",
    "            \n",
    "            # ========================= move the data to the GPU ==============================\n",
    "            features = features.to(device)\n",
    "            labels = labels.to(device)\n",
    "            # ====================================================================\n",
    "            \n",
    "            #forward\n",
    "            preds = net(features)\n",
    "            loss = loss_fn(preds,labels)\n",
    "\n",
    "            #metrics\n",
    "            step_metrics = {\"val_\"+name:metric_fn(preds, labels).item() \n",
    "                            for name,metric_fn in val_metrics_dict.items()}\n",
    "\n",
    "            step_log = dict({\"val_loss\":loss.item()},**step_metrics)\n",
    "\n",
    "            total_loss += loss.item()\n",
    "            step+=1\n",
    "            if i!=len(dl_val)-1:\n",
    "                loop.set_postfix(**step_log)\n",
    "            else:\n",
    "                epoch_loss = (total_loss/step)\n",
    "                epoch_metrics = {\"val_\"+name:metric_fn.compute().item() \n",
    "                                 for name,metric_fn in val_metrics_dict.items()}\n",
    "                epoch_log = dict({\"val_loss\":epoch_loss},**epoch_metrics)\n",
    "                loop.set_postfix(**epoch_log)\n",
    "\n",
    "                for name,metric_fn in val_metrics_dict.items():\n",
    "                    metric_fn.reset()\n",
    "                    \n",
    "    epoch_log[\"epoch\"] = epoch           \n",
    "    for name, metric in epoch_log.items():\n",
    "        history[name] = history.get(name, []) + [metric]\n",
    "\n",
    "    # 3，early-stopping -------------------------------------------------\n",
    "    arr_scores = history[monitor]\n",
    "    best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n",
    "    if best_score_idx==len(arr_scores)-1:\n",
    "        torch.save(net.state_dict(),ckpt_path)\n",
    "        print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n",
    "             arr_scores[best_score_idx]))\n",
    "    if len(arr_scores)-best_score_idx>patience:\n",
    "        print(\"<<<<<< {} without improvement in {} epoch, early stopping >>>>>>\".format(\n",
    "            monitor,patience))\n",
    "        break \n",
    "    net.load_state_dict(torch.load(ckpt_path))\n",
    "    \n",
    "dfhistory = pd.DataFrame(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the GPU, each epoch takes only about 10 seconds, roughly a 4x speedup.\n"
   ]
  },
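  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The speedup can be checked with a rough timing sketch (a minimal illustration with hypothetical layer and batch sizes; the numbers will vary with hardware):\n",
    "\n",
    "```python\n",
    "import time\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = nn.Linear(1024, 1024).to(device)\n",
    "x = torch.randn(256, 1024).to(device)\n",
    "\n",
    "start = time.time()\n",
    "for _ in range(100):\n",
    "    y = net(x)\n",
    "if device.type == \"cuda\":\n",
    "    torch.cuda.synchronize()  # wait for queued GPU kernels before reading the clock\n",
    "print(f\"{device.type}: {time.time() - start:.3f}s for 100 forward passes\")\n",
    "```"
   ]
  },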
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4, Using GPU in torchkeras.KerasModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the example above shows, using a GPU in PyTorch is not complicated, but for those who train models frequently, moving the model and data back and forth all the time is still rather tedious.\n",
    "\n",
    "It is easy to forget to move some piece of data or some module, which then raises an error.\n",
    "\n",
    "torchkeras.KerasModel was designed with this in mind: if a usable GPU exists in the environment, it is used automatically; otherwise the CPU is used.\n",
    "\n",
    "By building on some basic features of accelerate, torchkeras.KerasModel switches between GPU and CPU very elegantly.\n",
    "\n",
    "See the source code of torchkeras.KerasModel for the implementation details."
   ]
  },
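  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mechanism can be sketched with accelerate directly (a minimal illustration of the idea, assuming the accelerate package is installed; this is not KerasModel's actual source):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator()  # picks cuda automatically when available\n",
    "\n",
    "net = nn.Linear(4, 2)\n",
    "optimizer = torch.optim.Adam(net.parameters())\n",
    "ds = TensorDataset(torch.randn(8, 4), torch.randn(8, 2))\n",
    "dl = DataLoader(ds, batch_size=4)\n",
    "\n",
    "# prepare() moves the model and optimizer to accelerator.device and wraps\n",
    "# the dataloader so each batch arrives there without manual .to(device) calls\n",
    "net, optimizer, dl = accelerator.prepare(net, optimizer, dl)\n",
    "for features, labels in dl:\n",
    "    assert features.device.type == accelerator.device.type\n",
    "```"
   ]
  },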
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:31:06.804407Z",
     "iopub.status.busy": "2023-08-02T09:31:06.803922Z",
     "iopub.status.idle": "2023-08-02T09:31:06.819012Z",
     "shell.execute_reply": "2023-08-02T09:31:06.817856Z",
     "shell.execute_reply.started": "2023-08-02T09:31:06.804365Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import accelerate\n",
    "accelerator = accelerate.Accelerator()\n",
    "print(accelerator.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-02T09:31:12.618443Z",
     "iopub.status.busy": "2023-08-02T09:31:12.618036Z",
     "iopub.status.idle": "2023-08-02T09:32:49.031087Z",
     "shell.execute_reply": "2023-08-02T09:32:49.029676Z",
     "shell.execute_reply.started": "2023-08-02T09:31:12.618408Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31m<<<<<< ⚡️ cuda is used >>>>>>\u001b[0m\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAA5UUlEQVR4nO3deXhV5bX48e9KQiamhCSGQJiCEyACggp1gKIoWuuAUmupVVuldahtvdqLcqtWi/bXeutwta1YJ2qsVRxbqxIBp4oVcEAUEZlDQgghgYSQef3+eHfgJJyEE8jJPidZn+fZT87Z7x5WTk722u+w9xZVxRhjjGkuxu8AjDHGRCZLEMYYY4KyBGGMMSYoSxDGGGOCsgRhjDEmKEsQxhhjgrIE0YmJyAYROd3vOEIlIpNEJN/vODojETlTRF7yO45Qted3QUSOFZH322NbXY0lCNMmInK7iDzldxx+EZHvichGEdktIi+JSJ9Wlv22iKwUkQoReV9EhgeUiYj8RkS2iMhOEXlLREYElP9ORDaLyC5vf7cElB0pIi+LSLGI7BCRN0TkqAOEPgf4bbP4REQuFpGFIrJNRLaKyAIRmdb2T8ZfIvKEiNR4n3XjFAugqiuAMhH5ts9hRh1LEMaEyDuAPwxcCmQClcAfW1j2CCAX+AmQAvwDeEVE4rxFpgM/BE4B+gBLgL8GbOJR4GhV7QV8A5gRcOBOAV4BjvLi+BB4uZW4jwd6q+oHAfNigaeBq4C7gBxgAHA7MFNEHhEROcBHEml+p6o9Aqb6gLJc4Md+BRatLEF0fseLyBciUioij4tIYmOBiJwjIp+ISJl3hntsQNl/e2e35SKyWkROE5GpwC3Axd4Z2qfNd+atN7/ZvPtF5AHv9RUissrb7joRafM/rYjMEpG13ja+EJELmpVfFbCPL0TkOG/+ABF5wTvzLhGRB9u46xnAP1T1HVWtAH4FTBORnkGWPRN4V1XfU9U64P8B/YGJXvkQ4D1VXecdyJ4C9tYwVHW1qu4O2F4DcLhX9qGqPqqqO1S1FrgXOEpE0lqI+yzg7WbzZgPVwBRVXaiqFapaq6rvA2cDvXCJEAAROVpE8rway2oR+U5A2RMi8mevvFxE3haRQQHl3xCRpV5NaamIfCOgrI/3vSzwvqMvBQYpIv/l1W4KReSKFn6/ULwFnCYiCYewja5HVW3qpBOwAViJOzPsA/wb+I1XNgbYBpwIxAKXecsn4M5MNwP9vGUHA0O917cDT7Wyz0G4M+ue3vtYoBAY773/FjAUENzBshI4ziubBOSH8HtNB/rhTnAuBnYDWQFlW4DjvX0c7sUUC3yKO5h2BxKBk711TgbKWpkal3sZ+O9msVQAY4PEeB3wr4D3sUAV8LOAz2k5cCTQDfgd8FKzbczytq/AOiC7hc/jfKCwlc/rOeCmgPfdgU3ez3jgMe+78BbwOHAqLpktC1h+M3AFEOd9d7YDw73yJ4Byb70E4H5c8gP3vSvFJZs44BLvfZpX/irwdyDV+xwmBnwX6oA7vPlne9+V1BZ+xyeAHd60HLgwyDK7gGP9/r+Mpsn3AGwK4x/XHfB/EvD+bGCt9/pPwJ3Nll+NO2gf7h0wTge6NVvmdlpJEN4y7wE/8F5PadxnC8u+FHDQnEQICSLINj4BzvNev9G4vWbLTACKgbhD+DwXBn6e3rwtwKQgyx6NS1yTvIPwr3C1gJu98njvQKregXA9MCTIdsQ7IP8aL+k2K8/2Yriklbjzmn0PTgee9F5fDSzA1RhygILG3wdY7/28GFcbCtzmw8Bt3usngGcCynoA9bgTk0uBD5utuwS4HMjyPpP9Dvre57Yn8O/lfSfHt/A7Hgek4ZLQ2biEdVKQv9Wp7fX/1RUma2Lq/DYHvN6IO/MGdwb7X17zUpmIlOH+ofup6tfAz3HJYJuIPCMi/Qjd07gzRYDvee8BEJGzROQDr6miDPfPnN6WX0hEfhDQNFYG
HBOwjQHA2iCrDQA2qmvuOVgVuANpoF64g1ETqvolrlb2IK4GlQ58ATSOzLkVV8sZgKvN/BpYJCLJzbajqvox7mD568AyEcnAHdz/qKp/ayXuUiCwGeww3MESYCSu5rJLVdfhkjtes1ljE9cg4MRm35UZQN+Abe79nqlrftuB+671w33vAm3E1VAGADtUtbSFuEua/b0qcclnP6r6kaqWqGqdqv4L1+fQvLO9J65GaEJkCaLzGxDweiDuDBHcP/QcVU0JmJIbDzSq+rSqnow7OCiuDR3v9YE8B0wSkWzgArwE4bX/Pg/cA2SqagrwL9xZcki8tu1HcE04ad42VgZsYzOuCau5zcDAgE7iwG2e0mz0S/PpFG/Rz4FRAevl4JpUvgoWq6rOV9VjVDUNuA3XVLfUKx4N/F1V872D2hO4ZpbhwbaFOzPe+3uJSCouObyiqnNaWKfRClxTVqPtuLN3gM+A80Wkp4gMwTW3peI63x/zltkMvN3su9JDVa8O2Obe75mI9MA1LRV40yCaGohLUJuBPiKScoD4D4YS8L0Skf64WtvqMOyr8/K7CmNT+CZcE9NnuGaIPrizw7u8snG4f9ATcf9I3XH9Az1xfRCTcQe/xjbqxiaJn3jbiTnAvl/DNW18HDCvJ67pYaK3z7NwZ4WN/SKTOEATE+4AWuXFGItrF68DrvTKp3u/11iC90Hcw74+iJPa+HmOwLVjn+Jt4ykCmlaCLD/W228G8CzwdEDZbd7nmIk7UbsUd8ae4r3/Me5ALcAJuFrI9d66vXAjlx4MMe7jgK+a/R02AEkBf98iYBGuuehj4JrGv7G3/EYvxm7edDwwzCt/wvtcTva2dy/wb68sDXfW/j1ckrvYe5/ulb+KO4Fo7IM4taXvghfz6S38jhfhahcxwBm4Wt2kgPLvEdAnZFOI33m/A7ApjH9c9w91M65powx4EkgOKJ+KO6Mt8w5Az3kHg2O9A1A5rqngn+zrsE7zDmylwEet7PtS3FncTc3mX+sdjMpwwzqfoQ0JwltujhfXduAPuBE6VwaU/wR3pliBq12M8eYPxPV5lHjrPnAQn+n3cB28u3Gd1n0Cyl4Dbgl4/17AZ/gw0D2gLBF4yPvcdwEfAVO9shjgdW+9ClwN5RZAvPLLvM92t1feOA1sJe6lwIkB7+8CHm5h2f36aXAJ+VVcP04JLpmM9sqeAP6MOyGoAN4hoD8FlziWAzu9nycHlPXxvpdF3nfqhZa+C7SeIN71tr8LdyLw3WblrwLn+v0/GW1T4xfOGNOJicgZwDWqer73Pg53QhAD/AbX0Z8MnAf8EtcZXBHitp/AHcz/p90Dbwfe8O2HVXWC37FEG0sQxnRRIhKDG030Q2AYUAMsxjVDrmzDdp4gghOEOXj7ddgZ4zcRGYhrFgtmuKpu6sh4OitVbcD1Pzx2oGVN12Q1CGOMMUHZMFdjjDFBdZompvT0dB08eLDfYRhjTFRZvnz5dlXNCFbWaRLE4MGDWbZsmd9hGGNMVBGR5le672VNTMYYY4KyBGFMtCopgalT3U9jwsAShDHR6okn4I034Mkn/Y7EdFKWIIyJRqpw773u9b33uvfGtDNLEMZEo3ffhZ073euyMnjvPV/DMZ2TJQhjotCm3/6W+t3ucQ31u3ez6e67fY7IdEaWIIyJdOedByJNpsy8PGK9ZqVYVTLz8vZbhvPO8zlwE+0sQZiwyS0qYvCSJcS89RaDlywht6jI75Ci0113wcCBaGLi3lkJdU0fjNfkfWIiDBrk1jPmEFiCMGGRW1TEzNWr2VhdjQIbq6uZuXq1JYkDKK+rY0VFBS8VF/OHzZu57quvOLu+njFPPsn88ePZHZAkgtHkZFdz+PxzGDGig6LuXOzEZp9OcyW1iSw3rV1LZUNDk3mVDQ38ZPVqvt6zh6z4ePrGx5PlTZnx8XSL6fznK3UNDeRXV7O+qop1VVWs27Onyc/ttbVNlu8dG8vQpCQOz8hg6V/+wmHPPceJ
t91GYk3Nftuu6taNm6++mtIf/YgL9+xhSmIiibGxHfWrdQqNJzaN393GExuAGZmZfobmC0sQpt3sqa9nfnExDxcUUBjkAAZQ0dDA7Rs2BC1L79Ztb8LYm0ASEpq+j4+nR1xkf21La2tZV1XF+mYH/3V79rCxupq6gCGpcSIMSkggJymJC9PTyUlKYkhiIjlJSeQkJpLarVvTjU+eTM1dd0GQz1cTEuh9wgk8UVLCk0VF9IyN5Zy0NC7MyOCsPn1ItmTRosr6et4pK+Pqr74KemJz9VdfsaO2du/fZkhiIkld4POM7P80ExVW7d7N3MJCnty6ldK6Oo5ISiIlLo6yZu3kAIMSEvjqxBPZVlNDYcC0taaGwurqve9XVVaytaaG2iDj+3vExu5XA8lKSGj6Pj6ePt26ESOy3/rB5BYVMXvdOjZVVzMwIYE5OTktnjHWNDSwyasBrA9SC2j+e6d360ZOYiLH9+rFxQEHmJzERLITEohrS81p2TLive03iLAnPp6kmhpiVEmqr+f2khJuuegiFpeVMb+4mJe2b+dv27aRFBPD2X36cGFGBt9KS6NXhCfZcKtX5ePycvJKS8krLeXfO3dS08q1JOX19Vz/9ddN5mXFx5MT+Pf0knpOUhJZ8fEhf/ciWVifByEiU4H7cQ9u/4uq/rZZ+SDcw0oycM/f/b6q5ntlvwO+hesnyQN+pq0EO27cOLWb9XWcqvp6nt++nYcLCnh35066iTAtPZ0f9+vHpJQUnt62rUlVHSA5Joa5Rx0VclW9QZUdtbUuebSSTLbW1FBeX7/f+t1EyGxeIwmolTS+X1RayjVr1jSJNSkmhv8eMIAjk5P3SwKbq6sJPMeMF9nvAJGTmMgQ78DRrgfjSy6BZ55xHdGZmXDfffCzn8G2bVBV5cqffnrv4nUNDbyzcyfPFxfzwvbtbK2pIUGEM7xkcW5a2v61lE5qw549exPCwtJSdniJdlT37kzp04cpqalcuXo1m6ur91t3YEICS8eODXpCsG7PHvKbfScSRBgc5DvRmEx6ttN3IjcXZs+GTZtg4ECYMwdmzGjbNkRkuaqOC1oWrgQhIrG4h61PAfJxD02/RFW/CFjmOeCfqvqkiEwGrlDVS0XkG8DvgVO9Rd8DblbVt1ranyWIjvHl7t084tUWSurqGJqYyMx+/bi8b18Oi49vsmxbzsoP1e76+r1Jo0lCqa5u8r64WRt/qLLi44MmgQ4/W8zJcUeDiy6CRx+F7t1h92744Q/h+efd6KW1a4Ou2qDKkl27mF9czPPFxWyuriZOhNNSUrgwI4Pz09PJaPY3jGZltbUsKitzSWHHDtZWVQHQPz5+b0I4LTWVzIDfuXkfBIR2YhNYqwyWQHY2O4FprFU2TxxtqVXm5sIVuUXU/mAdHFYN2xLoNi+Hx2dktilJ+JUgJgC3q+qZ3vubAVT17oBlPgemqupmERFgp6r28tZ9EDgZEOAd4FJVXdXS/ixBhE91QwMveH0Lb+/cSZwI56en8+OsLCanpkZVVbq2oYFttbVNksnMr74KuqwAK48/nsGJiZHTfv+tb8GFF7qE0Nxjj7kk8eqrB9yMqrKsvHxvslhbVUUMcGpKChdlZHBBejr9EhLaP/4wqmlo4INdu/YmhKXl5TTgmiQnpaQwJTWVKampHJ2cjLTynQ3HiU1jv1Tz5LH+AP1SOYmJxG1PondlIj12JZFYmkjtjm4ceSRc+UwRJZevhsSAuktVDGlPHMX2Z0KP168EcRHu4H+l9/5S4ERVvS5gmaeB/6jq/SIyDXgeSFfVEhG5B7gS93/6oKrODrKPmcBMgIEDB47duLHF25qbg7CmspK5hYU8sXUr270OuplZWVzety99o+zg0ZrBS5awMUizwqCEBDZMmOBDRB1LVVmxe/feZLGqshIBJvTqxUUZGUzLyGDQAYbX+kFVWVVZuTchvL1zJxX19cQAJ/TqtTchjO/Vq00j5A612aauDioroVcv9/7DD2HL
FndnlLIy9zMzE37yE1f+/csa+LSgmh2JVexK3kNlShV9x+wh+/iqoCPbKI8jtSqR0p6VTZNDo60J6HdD/962liD87qm6EXhQRC7H1RK2APUicjgwDMj2lssTkVNU9d3AlVV1LjAXXA2iw6LuxGoaGnhx+3bmFhSwqKyMWOA8r2/h9CirLYRqTk5O0GaFOTk5PkbVcUSEUT16MKpHD+4cMoRVu3fzfHEx84uLuWHtWm5Yu5ZxPXtyUUYGF6anc3hysm+xFtXU8KaXEN4sLWWLN5rr8KQkLs3MZEpqKt9MSSHlIPtVnnoKZs6EPXvc+40b4cor4f333WUljQf5bt3gN79xy1x/Pbz55r6yykq37MqVrvznP4clS5ru59RT9yWI+NgYBsYnMbJnEikpqfTuDcf0hBljXfm/3qqjOK6KsuQ9lMRXUZSxh831VbxWUhH8l8jc/2TnYPnaxNRs+R7Al6qaLSI3AYmqeqdXditQpaq/a2l/1sR0aL6urOSRwkIe37qV4tpaBiUkcFW/fvywb1+yOlFtoSUd2V8STb6urOSF7duZX1zM0vJyAI7t3t0li4wMhnfvfsBtHMoZeWV9Pe/u3Enejh3klZaywrv/VJ+4OE7zaghTUlMZnJREQwPs2uUO1Dt3wvDhEBcHy5fDf/7TtGzXLpg3D2Ji4Ne/dq1zjWUHEh/vuoJWeQ3ed9wBK1ZASgr07u1+ZmfDFVe48hUroL5+X3mvXi6uQ5X+5hJK4vZPBml1CWw/vX1qEOFMEHG4TurTcDWDpcD3VPXzgGXSgR2q2iAic4B6Vb1VRC4GrgKm4pqYXgfuU9V/tLQ/SxBtV9PQwMvbtzO3sJA3S0uJBb7t9S1M6dOH2E5YW+gs2mP0SlttqqriBa9m8f6uXShwdHIyF6anc1FGBqN69NivbT83152RV1bum5ecDHPnNo23vt4dtEvLlP/sqOCtih2s7lHKkgo3/DS2Qehf2pu++amkfJ1KzLqePPqI0K8fPPgg3HILePlrr61bXVPOrbfCnXe6eXFx7iDdu7c7cHfv7h6rsXixO4A/8EDw310ECgrcMpHS2pZbVMQPP19NTcy+mm98QwyPjQh9pCD4lCC8HZ8N3Icb5vqYqs4RkTuAZar6itdPcTeguCama1W12hsB9UfcKCYFXlfVG1rblyWI0K3bs8fVFgoLKaqtZWBCAldmZfHDrCz6d4HaQrQL9aDbFg0N7tq7wKlnTzdVVbmz5cCyrTXVbBiwnYV1xbxdVkYDkF6TyKidGRxTkkFmWU/OP0846yzYeHgRXLlvpA1/ySH+3UzeeQdOPBH+9EIV1zy6A8aVwnGl0NsNPx0q3TmvfyosT+UPl6WQQOzeg3vv3m6079ChsGgRvPLKvrPzxvJvfct9LqWlUF3t5iUmuoN9SwYPds1KzQ0aBC1c3+mr9qj5+pYgOpIliNbVNjTwSkkJcwsKWFBaSgxwTloaP+7XjzOtthDxSkvdGWxBgbvUIdhTRpOTYcqUfQfxiy5y7dzl5TBhgptXW7uv/Kab4Je/hM2bXS2kuT/8AX7xC/jySxg2bP/yuXPhqqsgb2kNZ9y+HSYWw3FlEKdQlMDUHum8Pj8WLs7fb6QNr2Rx4SXKivhS1ngN/r1q4zmmKpVxksrJSamcMS6B3r1drKrQEecu4Ui+kS6SO6lNmG3waguPbd3K1poashMSuH3wYH7Uty/ZkVJXNrz7rjtzbUwCBQVwzDGueQTcGWzzJpTmKivdWW58vJsah9536wZHHbVvfuN0zDGuPCXFtcM3L28cwJWdDS+9tG9+t2772uEBTj02nvUP9SM+vh+7Y2pZWFnCP1OKeXNnAVwW5AQ0sQG+s4XXNYaJSSlc068fU/r0YXgLw0878tKMxiTQ0c13kcpqEJ1QXUMD/ygpYW5hIW/s2IEAZ6elMTMri7P69GnbrR3MQSkvd2f9jWfm8+bBJ59AYeG+BHD44fDaa658+PB9nZ49e0K/fnDWWfue
Kvr44+5Mtl8/+O533frNRVozSHldHb3efc/1IjanUD3xVOLtu+g7q0F0ERurqvhLYSGPFhZSWFNDv/h4fjVoED/KymKg1RZaFWqn7+7d+w7yZWVw7rlu/u9+5w72jQf/igp3ht14UXNurnsqaP/+kJUF48bBqFH7tvv00y4BZGW5BNFc44iYxn0FawaZM+eQP4Z21TMujkGJCcGvMUlMsOQQBSxBRJFgHVIXZ2Twrx07eLiggNd27ABgap8+/KlfP75ltYWQNG933rjRHZCfeQZeftkNhfzVr9wIl1279q0XH+86cEVcsqirg9Gj4eyz3Zn+oEH7ln3lFbd8S109o0eHHm80NYN09WtMop01MUWJYPeI6SZC95gYyurryYqP50dZWVyZlRWRV71GgvJyWL8e1q3bN915J4wZE3zkCkBxMaSnw7PPwr//7c7w+/VzU1aWaxqy/v3W2TUmkc1GMXUCLd0OIjEmhqeHDeOctLQu8cCd1jQ0uOadwARwxRUwZIjrA7jssqbLp6S4zuFjj3WjZJoTcds0pjOzPogot7u+PmhyAHcjvQsyMjo4Iv/s3t20FjB5sjvAv/uuG+IZ+DHFxLhx9kOGwAknwN13u36BoUPdz9RUt9zAgcFrEMGGfhrTlViCiGCbqqp4aMsW5hYWtrjMwAi+sO1grvZtaHBXwDYmgKOPdgf3DRtg/Hho/njg++93CWLoUHdPnJycfUlg4EA3JBPcdmbNCr7POXOio9PXmI5mCSLCqCof7NrFffn5PF9cjALTMjIYnpzMPZs3R01nX7CO35kz3etp0/bVAlJT4aST3Jn/mDFuvnfbfgBuuMEliMxMOOecpgkgJwf69HHL9evnRvccjGjq9DWmI1kfRISobWhgfnEx9+Xn82F5Ob1jY7mqXz+u699/b6dzNHX2tXTLgtjYfRdwAUyf7jqAAX7wAzjssH1JICfHjQSK4EqSMVHP+iAiWEltLXMLCnhwyxYKamo4IimJB484gssyM+nR7JaPMzIzIzYhNCoocFfdtjQqqL7ejRxqTACHH76vbN68DgnRGBMiSxA++Xz3bh7Iz2deURFVDQ2cnprK3KOO4qw+faL2mQs33gj/+7/udVycuy6guUGD4H/+p2PjMsYcHEsQHahBldd37OC+/HzySktJjInh0sxMru/fn2N69PA7vJCputtGvPiiqy289pq7QnjiRNenMG0afPSRdfwaE+0sQXSAiro65hUVcX9+Pl/t2UNWfDxzhgxhZlYW6VH0kPgtW1wN4cUX3aiimBg45RR3Z9H+/eHb33YT7Lv7p3X8GhO9rJM6jDZVVfHgli08UlhIWV0d43r25BfZ2VyUkREV96Gprnb32u/RwyWCwkLXb3DaaXDBBe4+RF3oEgxjOiXrpO5AqsoSb5jqC94w1QszMvh5djYTevUKejvjSFJRAa+/Di+8AK++6u49dP75LkFkZbnago+PJDbGdCBLEO2kJmCY6tLyclLi4rhhwACuDRimGqkqK/cd9M84wz1gPT3dDUG94AJXY2hkycGYrsMSxCHaXlPD3MJCHvKGqR6ZlMQfjziCH/TtS/fYWL/Da1F+vutgfuEFWLbMNR917w633+6uOzjppPZ5sLoxJnrZIeAgrayo4P4tW3jKG6Z6RmoqfznqKM6M8GGq774L//VfsHSpez9sGPz0p66/oXt3V4MwxhiwBNEmDaq85g1TfdMbpvqDzEyuz85mRPfufoe3H1U33PTFF92N7CZO3Pcwmrvucs1HRx/tb4zGmMhlCSIEFXV1POkNU12zZw/94uO5a8gQrorAYaoNDa6W8OKLbtq0yd3eolcvlyBGj4YPP/Q7SmNMNIj8sZY+2lhVxU1r15K9ZAnXrVlDalwcTw8bxobx47l50KAOTw65ue4eRzEx7mdurptfVQUrV7rXIvD978Of/+weafnYY+7uqL/8ZYeGaozpBLp8DWK/G+ANGcLgpKS9w1QFuMgbpjq+d2//4gxy
d9Qf/hAefNAlhx493IVsMTHwj3+4exxF0cXZxpgI1KUvlAv2GM8YoAFIjYtjZlYW1/bvz4AIGKba0t1RY2LgRz9yt7eYMsU1JxljTKjsQrkWzF63rklyAJcc+sTFsWnChIgaprppU/D5qjB3bsfGYozpGrp0H8SmFh7jWVpXF1HJAVp+/KU9FtMYEy5dOkG09LjOSHyM55w5+z84x+6OaowJpy6dIObk5JDc7KZ5kfoYzxkz4Mwz3WsR91yFuXPt7qjGmPDp0n0QjU9ni5bHeBYUwMknu+scjDEm3Lp0goDoeIwnuLuoLl8Ov/6135EYY7qKLt3EFE0WLnQjlqZM8TsSY0xXYQkiSowcCbfeCuOCjlY2xpj21+WbmKLFsGHWvGSM6VhWg4gChYWQl+duyW2MMR3FEkQUeOEF95yG/Hy/IzHGdCVhTRAiMlVEVovI1yIyK0j5IBFZKCIrROQtEckOKBsoIgtEZJWIfCEig8MZayRbsACGDIGhQ/2OxBjTlYQtQYhILPAQcBYwHLhERIY3W+weYJ6qHgvcAdwdUDYP+L2qDgNOALaFK9ZIVlsLixfb6CVjTMcLZw3iBOBrVV2nqjXAM8B5zZYZDizyXi9uLPcSSZyq5gGoaoWqVoYx1oj14YdQXm4JwhjT8cKZIPoDmwPe53vzAn0KTPNeXwD0FJE04EigTEReEJGPReT3Xo2ky1m82N1aY/JkvyMxxnQ1fndS3whMFJGPgYnAFqAeN/z2FK/8eCAHuLz5yiIyU0SWiciy4uLiDgu6I918M3z6KfTp43ckxpiuJpwJYgswIOB9tjdvL1UtUNVpqjoGmO3NK8PVNj7xmqfqgJeA45rvQFXnquo4VR2XkZERnt/CZ7Gx7iI5Y4zpaOFMEEuBI0RkiIjEA98FXglcQETSRaQxhpuBxwLWTRGRxqP+ZOCLMMYakRYtgp/+FEpL/Y7EGNMVhS1BeGf+1wFvAKuAZ1X1cxG5Q0TO9RabBKwWka+ATGCOt249rnlpoYh8BgjwSLhijVTPPw+PPw7du/sdiTGmK+rSz6SOdEce6aZ//tPvSIwxnVVrz6T2u5PatGDDBlizxl1BbYwxfrAEEaHy8txPu/7BGOMXSxARqqbG3dr76KP9jsQY01VZgohQ114LS5e6i+SMMcYPliAiUG2te3qcMcb4yRJEBPr9793dWyu75N2njDGRwhJEBMrLg9RUSE72OxJjTFdmCSLCVFTAv/9to5eMMf6zBBFh3nnH9UFYgjDG+M0SRITJy4OEBDj5ZL8jMcZ0dXF+B2CaOvtsGDgQkpL8jsQY09VZgogwU6ZY85IxJjJYE1MEWbkSPvnEroEwxkQGSxAR5Le/hTPPtARhjIkMliAihCq8+SacfjrE2F/FGBMB7FAUIT77DIqKrP/BGBM5LEFEiAUL3E9LEMaYSGEJIkIsXAjDh0P//n5HYowxjg1zjRDPPgubN/sdhTHG7GM1iAjRs6erQRhjTKSwBBEB5s51Q1yNMSaSWIKIAH/+M7z2mt9RGGNMU5YgfLZtG3z8MZxxht+RGGNMU5YgfLZwoftpw1uNMZHGEoTPGp8eN3as35EYY0xTliB8pgrnnAOxsX5HYowxTdl1ED57/HG/IzDGmOCsBuGj+nq/IzDGmJZZgvDRtGkwfbrfURhjTHCWIHxSWwuLFkFGht+RGGNMcJYgfPLBB1BRYcNbjTGRyxKET/Ly3IOBvvlNvyMxxpjgLEH4ZMECOPFESEnxOxJjjAnOhrn65IoroHdvv6MwxpiWWYLwyY9/7HcExhjTOmti8sEHH0BBgd9RGGNM60JKECJygYj0DnifIiLnh7DeVBFZLSJfi8isIOWDRGShiKwQkbdEJLtZeS8RyReRB0OJM1rMmAE/+YnfURhjTOtCrUHcpqo7G9+oahlwW2sriEgs8BBwFjAcuEREmj8z7R5gnqoeC9wB3N2s/E7gnRBjjApr18K6dXZ7b2NM5As1QQRb7kD9FycA
X6vqOlWtAZ4Bzmu2zHBgkfd6cWC5iIwFMoEFIcYYFfLy3E+7/sEYE+lCTRDLROQPIjLUm/4ALD/AOv2BzQHv8715gT4FpnmvLwB6ikiaiMQA/wvc2NoORGSmiCwTkWXFxcUh/ir+ysuDAQPgyCP9jsQYY1oXaoL4KVAD/B1XE6gCrm2H/d8ITBSRj4GJwBagHrgG+Jeq5re2sqrOVdVxqjouIwruWVFf726vMWUKiPgdjTHGtC6kYa6quhvYr5P5ALYAAwLeZ3vzArdbgFeDEJEewIWqWiYiE4BTROQaoAcQLyIVqtrWGCJKbCysWOHuw2SMMZEu1FFMeSKSEvA+VUTeOMBqS4EjRGSIiMQD3wVeabbddK85CeBm4DEAVZ2hqgNVdTCuljEv2pNDowEDICfH7yiMMebAQm1iSvdGLgGgqqXAYa2toKp1wHXAG8Aq4FlV/VxE7hCRc73FJgGrReQrXIf0nLaFH11+9St48UW/ozDGmNCIqh54IZHlwAWqusl7Pxh4QVWPC294oRs3bpwuW7bM7zBaVF4OffrAjTfC3c0H8xpjjE9EZLmqjgtWFuqtNmYD74nI24AApwAz2ym+LuHtt6Guzoa3GmOiR6id1K+LyDhcUvgYeAnYE8a4Op28PEhKgpNO8jsSY4wJTUgJQkSuBH6GG4n0CTAeWAJMDltknUxeHkycCAkJfkdijDGhCbWT+mfA8cBGVf0mMAYoC1dQnc3u3ZCYCGee6XckxhgTulD7IKpUtUpEEJEEVf1SRI4Ka2SdSPfu8NFHEMJ4AGOMiRihJoh87zqIl4A8ESkFNoYrqM6mocE9XtSunjbGRJOQmphU9QJVLVPV24FfAY8C54cxrk6joQGGDoV77vE7EmOMaZs2P1FOVd8ORyCd1aefwoYNkJnpdyTGGNM29kS5MFvg3az89NP9jcMYY9rKEkSY5eXBMcdAVpbfkRhjTNtYggijPXvgvffs6mljTHRqcx+ECV11NcyaBVOn+h2JMca0nSWIMEpJgdtv9zsKY4w5ONbEFEZvv+2uojbGmGhkCSJMiopg0iT4v//zOxJjjDk4liDC5M033U/roDbGRCtLEGGSlwdpaTBmjN+RGGPMwbEEEQaqLkGcfrq7B5MxxkQjO3yFwapVUFBgzUvGmOhmw1zD4OijYflyGDTI70iMMebgWYIIg5gYOO44v6MwxphDY01M7ay6Gq69Fj7+2O9IjDHm0FiCaGdLlsAf/wibNvkdiTHGHBpLEO0sLw9iY+Gb3/Q7EmOMOTSWINrZggUwfjz06uV3JMYYc2gsQbSjkhI3esmGtxpjOgNLEO1o/Xro188ShDGmc7Bhru1o3DjYvNnvKIwxpn1Ygmgnqu6niL9xGGNMe7Empnby9dfQvz8sXOh3JMYY0z4sQbSTvDwoLLTbaxhjOg9LEO0kLw8GD4ahQ/2OxBhj2ocliHZQVweLFrnRS9YHYYzpLCxBtIMPP4Rdu+CMM/yOxBhj2o8liHaQmgrXXAOTJ/sdiTHGtJ+wJggRmSoiq0XkaxGZFaR8kIgsFJEVIvKWiGR780eLyBIR+dwruziccR6qYcPgoYegTx+/IzHGmPYTtgQhIrHAQ8BZwHDgEhEZ3myxe4B5qnoscAdwtze/EviBqo4ApgL3iUhKuGI9FBUVsGwZ1Nf7HYkxxrSvcNYgTgC+VtV1qloDPAOc12yZ4cAi7/XixnJV/UpV13ivC4BtQEYYYz1oCxfC8cfDu+/6HYkxxrSvcCaI/kDgjSfyvXmBPgWmea8vAHqKSFrgAiJyAhAPrG2+AxGZKSLLRGRZcXFxuwXeFnl50L07TJjgy+6NMSZs/O6kvhGYKCIfAxOBLcDexhoRyQL+Clyhqg3NV1bVuao6TlXHZWT4U8FYsAAmToSEBF92b4wxYRPOezFtAQYEvM/25u3lNR9NAxCRHsCFqlrmve8FvArMVtUPwhjnQdu4EdascSOYjDGmswlnDWIpcISIDBGReOC7
wCuBC4hIuog0xnAz8Jg3Px54EdeBPT+MMR6SvDz3027vbYzpjMKWIFS1DrgOeANYBTyrqp+LyB0icq632CRgtYh8BWQCc7z53wFOBS4XkU+8aXS4Yj1Y3/uea2Ia3nxsljHGdAKijfepjnLjxo3TZcuW+R2GMcZEFRFZrqrjgpX53Ukdtb74Am69FYqK/I7EGGPCwxLEQXr5ZbjzTrs5nzGm87IEcZDy8mDUKDjsML8jMcaY8LAEcRAqK+Hf/7bRS8aYzs0SxEF45x2oqbHbextjOjdLEAdh82ZIT4eTT/Y7EmOMCR9LEAfhqqtg61ZISvI7EmOMCR9LEG3UeNlIbKy/cRhjTLhZgmijp55yo5e2bvU7EmOMCS9LEG20YAEUFtrwVmNM52cJog1U3fUPp58OMfbJGWM6OTvMtcFnn7lba9j1D8aYrsASRBvY7b2NMV2JJYg2GD4crr8esrP9jsQYY8IvnE+U63TOOstNxhjTFVgNIkRbtrhHjBpjTFdhCSJE//d/cMQRsHu335EYY0zHsAQRorw8mDABunf3OxJjjOkYliBCsH07fPyxjV4yxnQtliBCsHChu0jObu9tjOlKLEGEYMECSE2FsWP9jsQYYzqOJYgQzJkDL75od3A1xnQtdh1ECPr2dZMxxj+1tbXk5+dTVVXldyhRKTExkezsbLp16xbyOpYgDuDVV2H9erj6aqtBGOOn/Px8evbsyeDBgxERv8OJKqpKSUkJ+fn5DBkyJOT1rInpAB5+GO6915KDMX6rqqoiLS3NksNBEBHS0tLaXPuyBNGK2lpYvNiGtxoTKSw5HLyD+ewsQbTigw+gosKGtxpjuiZLEK3Iy3MPBpo82e9IjDFtlZsLgwe7/+HBg9170zaWIFpRUADjx0NKit+RGGPaIjcXZs50N9hUdT9nzjy0JFFWVsYf//jHNq939tlnU1ZWdvA79pGoqt8xtItx48bpsmXL2n27tbXQhlFhxpgwWbVqFcOGDdv7ftKk/Zf5znfgmmtg4EDYvHn/8rQ0d+uc7dvhooualr31Vuv737BhA+eccw4rV65sMr+uro64uOgYENr8MwQQkeWqOi7Y8laDOABLDsZEn/z84PNLSg5+m7NmzWLt2rWMHj2a448/nlNOOYVzzz2X4cOHA3D++eczduxYRowYwdy5c/euN3jwYLZv386GDRsYNmwYV111FSNGjOCMM85gz549Le7vkUce4fjjj2fUqFFceOGFVFZWAlBUVMQFF1zAqFGjGDVqFO+//z4A8+bN49hjj2XUqFFceumlB/+LBlLVTjGNHTtW29MvfqE6bVq7btIYcwi++OKLkJcdNEjVNS41nQYNOvj9r1+/XkeMGKGqqosXL9bk5GRdt27d3vKSkhJVVa2srNQRI0bo9u3bvVgGaXFxsa5fv15jY2P1448/VlXV6dOn61//+tcW99e4vqrq7Nmz9YEHHlBV1e985zt67733qqpqXV2dlpWV6cqVK/WII47Q4uLiJrE0F+wzBJZpC8dVq0G04JVXXPOSMSb6zJkDyclN5yUnu/nt5YQTTmhy0dkDDzzAqFGjGD9+PJs3b2bNmjX7rTNkyBBGjx4NwNixY9mwYUOL21+5ciWnnHIKI0eOJDc3l88//xyARYsWcfXVVwMQGxtL7969WbRoEdOnTyc9PR2APn36tMvvaAkiiPXrYe1au/7BmGg1YwbMnQuDBoGI+zl3rpvfXroHPBzmrbfe4s0332TJkiV8+umnjBkzJuhFaQkJCXtfx8bGUldX1+L2L7/8ch588EE+++wzbrvtNl9uMWIJIoi8PPfTEoQx0WvGDNiwARoa3M9DTQ49e/akvLw8aNnOnTtJTU0lOTmZL7/8kg8++ODQdgaUl5eTlZVFbW0tuQHDr0477TT+9Kc/AVBfX8/OnTuZPHkyzz33HCVeJ8uOHTsOef8Q5gQhIlNFZLWIfC0is4KUDxKRhSKyQkTeEpHsgLLLRGSNN10Wzjib
W7AAsrPhqKM6cq/GmEiWlpbGSSedxDHHHMNNN93UpGzq1KnU1dUxbNgwZs2axfjx4w95f3feeScnnngiJ510EkcfffTe+ffffz+LFy9m5MiRjB07li+++IIRI0Ywe/ZsJk6cyKhRo7jhhhsOef8QxmGuIhILfAVMAfKBpcAlqvpFwDLPAf9U1SdFZDJwhapeKiJ9gGXAOECB5cBYVS1taX/tOcz1gQeguhqafQeMMT4KNkTTtE1bh7mGc/DuCcDXqrrOC+IZ4Dzgi4BlhgONqW4x8JL3+kwgT1V3eOvmAVOBv4Ux3r2uv74j9mKMMZEtnE1M/YHAS1XyvXmBPgWmea8vAHqKSFqI6yIiM0VkmYgsKy4ubpeg16yBXbvaZVPGGHNA1157LaNHj24yPf74436HBfj/PIgbgQdF5HLgHWALUB/qyqo6F5gLrompPQK68kqorISlS9tja8YY07qHHnrI7xBaFM4axBZgQMD7bG/eXqpaoKrTVHUMMNubVxbKuuFQXg5LltjN+YwxBsKbIJYCR4jIEBGJB74LvBK4gIiki0hjDDcDj3mv3wDOEJFUEUkFzvDmhdXbb7uL4+z23sYYE8YEoap1wHW4A/sq4FlV/VxE7hCRc73FJgGrReQrIBOY4627A7gTl2SWAnc0dliHU14eJCbCSSeFe0/GGBP5wtoHoar/Av7VbN6tAa/nA/NbWPcx9tUoOkReHpx6qksSxpjolltUxOx169hUXc3AhATm5OQwIzPT77Ciit+d1BHl73931z8YY6JbblERM1evprKhAYCN1dXMXL0aoMOSRI8ePaioqOiQfYWLJYgAI0f6HYExJhQ/X7OGT1o5+H6waxfVzS4Crmxo4EdffskjBQVB1xndowf3HXFEu8YZ7exeTJ5HHoF//tPvKIwx7aF5cjjQ/FDMmjWryZDU22+/nd/85jecdtppHHfccYwcOZKXX345pG1VVFS0uF6w5zq09AyIsGvpPuDRNh3K8yDq61UzMlS///2D3oQxJsza9DyI999XFi/ebxr0/vsHvf+PPvpITz311L3vhw0bpps2bdKdO3eqqmpxcbEOHTpUGxoaVFW1e/fuLW6rtrY26HotPdch2DMgDkZbnwdhTUzAihVQXGx3bzWms5iTk9OkDwIgOSaGOTk5B73NMWPGsG3bNgoKCiguLiY1NZW+ffvyi1/8gnfeeYeYmBi2bNlCUVERffv2bXVbqsott9yy33otPddh0aJFzJs3D9j3DIiO0OUTRG4uXHede33LLRAb2773jDfGdLzGjuj2HsU0ffp05s+fz9atW7n44ovJzc2luLiY5cuX061bNwYPHhzScxsOdr2O1qX7IHJzYeZMKCtz77dsce8Dbr1ujIlSMzIz2TBhAg2TJrFhwoR2Gb108cUX88wzzzB//nymT5/Ozp07Oeyww+jWrRuLFy9m48aNIW2npfVaeq5DsGdAdIQunSBmz3b3XQpUWenmG2NMcyNGjKC8vJz+/fuTlZXFjBkzWLZsGSNHjmTevHlNntvQmpbWa+m5DsGeAdERwvY8iI52MM+DiIlxjzJvTsQ9hcoYEznseRCHrq3Pg+jSNYiBA9s23xhjupIu3Uk9Z47rcwhsZkpOdvONMeZQffbZZ3uvZWiUkJDAf/7zH58iapsunSAaRyvNng2bNrmaw5w5NorJmEilqoiI32GEbOTIkXzyySd+hwG4z66tunSCAJcMLCEYE/kSExMpKSkhLS0tqpJEJFBVSkpKSGzjnUi7fIIwxkSH7Oxs8vPzaa/HC3c1iYmJZGdnt2kdSxDGmKjQrVs3hgwZ4ncYXUqXHsVkjDGmZZYgjDHGBGUJwhhjTFCd5kpqESkGQrsRSnDpwPZ2CifcoilWiK54oylWiK54oylWiK54DyXWQaqaEayg0ySIQyUiy1q63DzSRFOsEF3xRlOsEF3xRlOsEF3xhitWa2IyxhgTlCUIY4wxQVmC2Geu3wG0QTTFCtEVbzTF
CtEVbzTFCtEVb1hitT4IY4wxQVkNwhhjTFCWIIwxxgTV5ROEiDwmIttEZKXfsRyIiAwQkcUi8oWIfC4iP/M7ppaISKKIfCgin3qx/trvmA5ERGJF5GMR+affsRyIiGwQkc9E5BMRadujFH0gIikiMl9EvhSRVSIywe+YghGRo7zPtHHaJSI/9zuu1ojIL7z/sZUi8jcRadstW1vbdlfvgxCRU4EKYJ6qHuN3PK0RkSwgS1U/EpGewHLgfFXtmAfUtoG4+zF3V9UKEekGvAf8TFU/8Dm0FonIDcA4oJeqnuN3PK0RkQ3AOFWNigu5RORJ4F1V/YuIxAPJqlrmc1itEpFYYAtwoqoeykW4YSMi/XH/W8NVdY+IPAv8S1WfaI/td/kahKq+A+zwO45QqGqhqn7kvS4HVgH9/Y0qOHUqvLfdvCliz0ZEJBv4FvAXv2PpbESkN3Aq8CiAqtZEenLwnAasjdTkECAOSBKROCAZKGivDXf5BBGtRGQwMAaI2GcXek02nwDbgDxVjdhYgfuAXwINPscRKgUWiMhyEZnpdzAHMAQoBh73mvD+IiLd/Q4qBN8F/uZ3EK1R1S3APcAmoBDYqaoL2mv7liCikIj0AJ4Hfq6qu/yOpyWqWq+qo4Fs4AQRicgmPBE5B9imqsv9jqUNTlbV44CzgGu9ptJIFQccB/xJVccAu4FZ/obUOq8Z7FzgOb9jaY2IpALn4ZJwP6C7iHy/vbZvCSLKeO35zwO5qvqC3/GEwmtOWAxM9TmUlpwEnOu16z8DTBaRp/wNqXXemSOqug14ETjB34halQ/kB9Qg5+MSRiQ7C/hIVYv8DuQATgfWq2qxqtYCLwDfaK+NW4KIIl7H76PAKlX9g9/xtEZEMkQkxXudBEwBvvQ1qBao6s2qmq2qg3HNCotUtd3OwtqbiHT3BingNdWcAUTsKDxV3QpsFpGjvFmnARE3sKKZS4jw5iXPJmC8iCR7x4fTcH2T7aLLJwgR+RuwBDhKRPJF5Ed+x9SKk4BLcWe4jcPwzvY7qBZkAYtFZAWwFNcHEfHDR6NEJvCeiHwKfAi8qqqv+xzTgfwUyPW+D6OBu/wNp2Ve0p2COxuPaF6tbD7wEfAZ7pjebrfd6PLDXI0xxgTX5WsQxhhjgrMEYYwxJihLEMYYY4KyBGGMMSYoSxDGGGOCsgRhTAQQkUnRcBdZ07VYgjDGGBOUJQhj2kBEvu895+ITEXnYuyFhhYjc692Tf6GIZHjLjhaRD0RkhYi86N03BxE5XETe9J6V8ZGIDPU23yPgmQm53pWxxvjGEoQxIRKRYcDFwEneTQjrgRlAd2CZqo4A3gZu81aZB/y3qh6Lu8q1cX4u8JCqjsLdN6fQmz8G+DkwHMjBXTlvjG/i/A7AmChyGjAWWOqd3CfhbmXeAPzdW+Yp4AXvGQgpqvq2N/9J4DnvHkr9VfVFAFWtAvC296Gq5nvvPwEG4x4GY4wvLEEYEzoBnlTVm5vMFPlVs+UO9v411QGv67H/T+Mza2IyJnQLgYtE5DAAEekjIoNw/0cXect8D3hPVXcCpSJyijf/UuBt70mA+SJyvreNBBFJ7shfwphQ2RmKMSFS1S9E5H9wT3KLAWqBa3EPwDnBK9uG66cAuAz4s5cA1gFXePMvBR4WkTu8bUzvwF/DmJDZ3VyNOUQiUqGqPfyOw5j2Zk1MxhhjgrIahDHGmKCsBmGMMSYoSxDGGGOCsgRhjDEmKEsQxhhjgrIEYYwxJqj/DwcgXwgLwaYCAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* background: */\n",
       "    progress::-webkit-progress-bar {background-color: #CDCDCD; width: 100%;}\n",
       "    progress {background-color: #CDCDCD;}\n",
       "\n",
       "    /* value: */\n",
       "    progress::-webkit-progress-value {background-color: #00BFFF  !important;}\n",
       "    progress::-moz-progress-bar {background-color: #00BFFF  !important;}\n",
       "    progress {color: #00BFFF ;}\n",
       "\n",
       "    /* optional */\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #000000;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      <progress value='8' class='progress-bar-interrupted' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      80.00% [8/10] [01:35<00:23]\n",
       "      <br>\n",
       "      ████████████████████100.00% [79/79] [val_loss=0.0731, val_acc=0.9795]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31m<<<<<< val_acc without improvement in 3 epoch,early stopping >>>>>> \n",
      "\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>train_acc</th>\n",
       "      <th>lr</th>\n",
       "      <th>val_loss</th>\n",
       "      <th>val_acc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0.327403</td>\n",
       "      <td>0.893783</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.105610</td>\n",
       "      <td>0.9660</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.109891</td>\n",
       "      <td>0.966483</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.086295</td>\n",
       "      <td>0.9746</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0.092442</td>\n",
       "      <td>0.972733</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.058266</td>\n",
       "      <td>0.9825</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0.085473</td>\n",
       "      <td>0.975367</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.072749</td>\n",
       "      <td>0.9806</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0.079213</td>\n",
       "      <td>0.977350</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.059756</td>\n",
       "      <td>0.9832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>0.077976</td>\n",
       "      <td>0.977800</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.081202</td>\n",
       "      <td>0.9768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>0.074797</td>\n",
       "      <td>0.978950</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.076534</td>\n",
       "      <td>0.9821</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>0.072074</td>\n",
       "      <td>0.980133</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.073126</td>\n",
       "      <td>0.9795</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   epoch  train_loss  train_acc    lr  val_loss  val_acc\n",
       "0      1    0.327403   0.893783  0.01  0.105610   0.9660\n",
       "1      2    0.109891   0.966483  0.01  0.086295   0.9746\n",
       "2      3    0.092442   0.972733  0.01  0.058266   0.9825\n",
       "3      4    0.085473   0.975367  0.01  0.072749   0.9806\n",
       "4      5    0.079213   0.977350  0.01  0.059756   0.9832\n",
       "5      6    0.077976   0.977800  0.01  0.081202   0.9768\n",
       "6      7    0.074797   0.978950  0.01  0.076534   0.9821\n",
       "7      8    0.072074   0.980133  0.01  0.073126   0.9795"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torchkeras import KerasModel \n",
    "from torchmetrics import Accuracy\n",
    "\n",
    "net = create_net() \n",
    "model = KerasModel(net,\n",
    "                   loss_fn=nn.CrossEntropyLoss(),\n",
    "                   metrics_dict = {\"acc\":Accuracy(task='multiclass',num_classes=10)},\n",
    "                   optimizer = torch.optim.Adam(net.parameters(),lr = 0.01)  )\n",
    "\n",
    "model.fit(\n",
    "    train_data = dl_train,\n",
    "    val_data= dl_val,\n",
    "    epochs=10,\n",
    "    patience=3,\n",
    "    monitor=\"val_acc\", \n",
    "    mode=\"max\")"
   ]
  },
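  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a side note (a minimal sketch, not part of the original notebook): you can verify where a model's parameters actually live by inspecting their `.device` attribute. The tiny `nn.Linear` below is a hypothetical stand-in for the `net` created above.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical small network standing in for `net` from the training cell.\n",
    "net = nn.Linear(10, 2)\n",
    "print(next(net.parameters()).device)  # cpu by default\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    net = net.to(\"cuda:0\")  # move the parameters onto the GPU\n",
    "    print(next(net.parameters()).device)  # now cuda:0\n",
    "```"
   ]
  },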
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "**If this book has been helpful to you and you would like to encourage the author, remember to give this project a star ⭐️ and share it with your friends 😊!**\n",
     "\n",
     "If you would like to discuss the book's content further with the author, feel free to leave a message under the WeChat official account \"算法美食屋\". The author's time and energy are limited, so replies are given as circumstances allow.\n",
     "\n",
     "You can also reply with the keyword **加群** (join group) in the official account backend to join the reader discussion group.\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
